#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK
from __future__ import annotations

import sys
import os
import signal
import argparse
from time import strftime, gmtime
import logging
import random
import argcomplete # type: ignore

from datetime import date
import warnings
from cryptography.utils import CryptographyDeprecationWarning
with warnings.catch_warnings():
    if date.today() < date(2025, 1, 1):
        # Temporarily wave CryptographyDeprecationWarning while we wait on a paramiko-ng release
        warnings.filterwarnings('ignore', category=CryptographyDeprecationWarning)
    from fabric.api import local, hide, warn_only, env, execute, parallel # type: ignore

import string
from inspect import signature
from warnings import warn
from pathlib import Path

from runtools.runtime_config import RuntimeConfig

from awstools.awstools import valid_aws_configure_creds, get_aws_userid, subscribe_to_firesim_topic, awsinit
from awstools.afitools import share_agfi_in_all_regions

from buildtools.buildconfigfile import BuildConfigFile
from buildtools.bitbuilder import F1BitBuilder

from util.streamlogger import StreamLogger, InfoStreamLogger
from util.filelineswap import file_line_swap
from util.io import firesim_input

from typing import Dict, Callable, Type, Optional, TypedDict, get_type_hints, Tuple, List

# Names of every entry found in the `platforms` directory that sits next to
# the directory containing this file; offered as the choices for --platform.
PLATFORM_LIST = [
    platform_entry.name
    for platform_entry in (Path(__file__).parent.parent / 'platforms').iterdir()
]

class Task(TypedDict):
    """Shape of one entry stored in the `TASKS` registry."""
    # the task callable itself
    task: Callable
    # type used to build the task's argument, or None when the task takes no parameters
    config: Optional[Callable]

# Registry of tasks: maps a task's name to its `Task` entry.
TASKS: Dict[str, Task] = {}

class FiresimTaskAccessViolation(RuntimeError):
    """Raised when a registered task is called directly instead of via `TASKS`."""

def register_task(task: Callable) -> Callable:
    """Decorator that records `task` in `TASKS` under the key `task.__name__`.

    Args:
        task: callable task. It either takes no parameters, or its first
              parameter is annotated with `argparse.Namespace` itself or with
              a class whose constructor's first parameter is annotated as
              `argparse.Namespace`.

    Returns:
        a stub that raises `FiresimTaskAccessViolation` when called, so that
        the task is only reachable via `TASKS`

    Raises:
        KeyError: a task of the same name has already been registered
        TypeError: a required type annotation is missing, or the config
                   class constructor takes no parameters
        AssertionError: the config class constructor's first parameter is
                        annotated with something other than `argparse.Namespace`
    """
    task_name = task.__name__
    if task_name in TASKS:
        raise KeyError(f"Task '{task_name}' already registered by {TASKS[task_name]}")

    # turn string annotations into the objects they name
    task.__annotations__ = get_type_hints(task)

    # the config type, if any, is whatever the task's first parameter is annotated with
    config_class = None
    task_params = list(signature(task).parameters.values())
    if task_params:
        first = task_params[0]
        if first.annotation is first.empty:
            raise TypeError(f'{task_name}, requires type annotation on first parameter, {first.name}')

        config_class = first.annotation

        # turn string annotations on the constructor into the objects they name
        config_class.__init__.__annotations__ = get_type_hints(config_class.__init__)

        # anything other than a plain Namespace must be constructible from a Namespace
        if config_class is not argparse.Namespace:
            ctor_params = list(signature(config_class).parameters.values())
            if not ctor_params:
                raise TypeError(f'{task_name}, first parameter {first.name} constructor takes no parameters? Expected argparse.Namespace.')

            cfirst = ctor_params[0]
            if cfirst.annotation is cfirst.empty:
                raise TypeError(f'{task_name}, first parameter, {first.name}, needs type annotation on '+
                                f'{config_class.__name__}.__init__ first parameter, {cfirst.name}')
            assert cfirst.annotation is argparse.Namespace

    TASKS[task_name] = {'task': task, 'config': config_class}

    # The original callable is deliberately not handed back: every call has to
    # go through firesim.main() so that programmatic changes to TASKS are
    # followed (currently for testing purposes). A def-based closure is used
    # because a lambda that raises is convoluted.
    def inner(*args, **kwargs):
        raise FiresimTaskAccessViolation(f"Called '{task_name}' without going through TASKS")
    return inner

## below are tasks that users can call
## to add a task, define it here and apply the @register_task decorator to its definition


@register_task
def managerinit(args: argparse.Namespace):
    """ Setup local FireSim manager components.

    Steps, all using paths relative to the current working directory:

    1. Copy each existing ``config_<name>.yaml`` to
       ``sample-backup-configs/backup_config_<name>.yaml`` (failures are
       tolerated, e.g. when the file does not exist yet).
    2. Recreate each ``config_<name>.yaml`` from
       ``sample-backup-configs/sample_config_<name>.yaml``.
    3. For a known platform, splice the section of the platform's run farm
       and build farm recipes found between the ``managerinit arg start`` and
       ``managerinit arg end`` marker lines into ``config_runtime.yaml`` and
       ``config_build.yaml`` respectively.
    4. For the ``f1`` platform, run ``awsinit()``.

    Args:
        args: parsed command line arguments; only ``args.platform`` is read.

    Raises:
        SystemExit: if ``args.platform`` is empty.
        ValueError: if a recipe file is missing one of its marker lines.
    """

    if not args.platform:
        rootLogger.error('::ERROR:: managerinit requires that you provide --platform <platformname>')
        sys.exit(1)

    config_files = ["build", "build_recipes", "hwdb", "runtime"]

    # Platforms that use the externally provisioned recipes, mapped to the
    # deploy manager name substituted into the run farm recipe. This mapping
    # is the single source of truth for which non-f1 platforms are known.
    deploy_manager_map = {
        'vitis': 'VitisInstanceDeployManager',
        'xilinx_alveo_u200': 'XilinxAlveoU200InstanceDeployManager',
        'xilinx_alveo_u250': 'XilinxAlveoU250InstanceDeployManager',
        'xilinx_alveo_u280': 'XilinxAlveoU280InstanceDeployManager',
        'xilinx_vcu118': 'XilinxVCU118InstanceDeployManager',
        'rhsresearch_nitefury_ii': 'RHSResearchNitefuryIIInstanceDeployManager',
    }

    def get_indexes(lines: List[str], keyword_start: str, keyword_end: str) -> Tuple[int, int]:
        """Return slice bounds selecting the lines strictly between the marker lines.

        If a marker appears more than once, its last occurrence is used.
        """
        start_arg_idx = None
        end_arg_idx = None
        for i, line in enumerate(lines):
            if keyword_start in line:
                start_arg_idx = i + 1
            if keyword_end in line:
                end_arg_idx = i

        # explicit raises (rather than asserts) so the check survives `python -O`;
        # a None bound would otherwise silently slice in the whole recipe
        if start_arg_idx is None:
            raise ValueError(f"""Unable to find "{keyword_start}" in input lines""")
        if end_arg_idx is None:
            raise ValueError(f"""Unable to find "{keyword_end}" in input lines""")

        return start_arg_idx, end_arg_idx

    def recipe_override_lines(recipe_file: str) -> List[str]:
        """Build the lines to splice into a config file from the given recipe file."""
        with open(recipe_file, "r") as f:
            recipe_lines = f.readlines()
        start_arg_idx, end_arg_idx = get_indexes(recipe_lines, "managerinit arg start", "managerinit arg end")
        header = [f"  base_recipe: {recipe_file}\n", "  recipe_arg_overrides:\n"]
        return header + recipe_lines[start_arg_idx:end_arg_idx]

    def swap_in_overrides(config_file: str, override_lines: List[str]) -> None:
        """Rewrite `config_file` in place via `file_line_swap`, passing the replace markers and new lines."""
        file_line_swap(
            config_file,
            config_file,
            "managerinit replace start",
            "managerinit replace end",
            override_lines)

    rootLogger.info("Backing up initial config files, if they exist.")
    for conf_file in config_files:
        # warn_only: a missing config file must not abort the whole task
        with warn_only(), hide('everything'):
            m = local(f"cp config_{conf_file}.yaml sample-backup-configs/backup_config_{conf_file}.yaml", capture=True)
            rootLogger.debug(m)
            rootLogger.debug(m.stderr)

    rootLogger.info("Creating initial config files from examples.")
    with hide('everything'):
        for conf_file in config_files:
            m = local(f"cp sample-backup-configs/sample_config_{conf_file}.yaml config_{conf_file}.yaml", capture=True)
            rootLogger.debug(m)
            rootLogger.debug(m.stderr)

    rootLogger.info("Adding default overrides to config files")
    platform = args.platform
    if platform == 'f1':
        recipe_name: Optional[str] = "aws_ec2.yaml"
    elif platform in deploy_manager_map:
        recipe_name = "externally_provisioned.yaml"
    else:
        recipe_name = None
        rootLogger.info(f"Unknown platform {args.platform}. Skipping default config overrides.")

    if recipe_name is not None:
        rf_recipe_lines = recipe_override_lines(f"run-farm-recipes/{recipe_name}")
        if platform in deploy_manager_map:
            # the recipe section mentions the EC2 deploy manager and /home/centos;
            # swap in the platform's deploy manager and a directory under $HOME
            firesim_runs_dir = str(Path.home()) + "/FIRESIM_RUNS_DIR"
            deploy_manager = deploy_manager_map[platform]
            rf_recipe_lines = [
                line.replace("EC2InstanceDeployManager", deploy_manager).replace("/home/centos", firesim_runs_dir)
                for line in rf_recipe_lines
            ]
        swap_in_overrides("config_runtime.yaml", rf_recipe_lines)

        bf_recipe_lines = recipe_override_lines(f"build-farm-recipes/{recipe_name}")
        swap_in_overrides("config_build.yaml", bf_recipe_lines)

    if platform == 'f1':
        awsinit()

@register_task
def infrasetup(runtime_conf: RuntimeConfig) -> None:
    """ Run the infrasetup step by delegating to ``runtime_conf.infrasetup()``.

    Args:
        runtime_conf: runtime configuration the step is invoked on.
    """
    runtime_conf.infrasetup()

@register_task
def boot(runtime_conf: RuntimeConfig) -> None:
    """ Run the boot step by delegating to ``runtime_conf.boot()``.

    Args:
        runtime_conf: runtime configuration the step is invoked on.
    """
    runtime_conf.boot()

@register_task
def kill(runtime_conf: RuntimeConfig) -> None:
    """ Run the kill step by delegating to ``runtime_conf.kill()``.

    Args:
        runtime_conf: runtime configuration the step is invoked on.
    """
    runtime_conf.kill()

@register_task
def runworkload(runtime_conf: RuntimeConfig) -> None:
    """ Run the workload by delegating to ``runtime_conf.run_workload()``.

    Args:
        runtime_conf: runtime configuration the step is invoked on.
    """
    runtime_conf.run_workload()

@register_task
def buildbitstream(build_config_file: BuildConfigFile) -> None:
    """ Starting from local Chisel, build a bitstream for all of the specified
    hardware configs.

    The RTL replacement and driver build run on localhost for every build
    config; build hosts are then requested and the bitstream builds are
    executed in parallel across ``build_config_file.build_ip_set``.

    While this task runs, SIGINT is handled by offering to release the build
    hosts before exiting.

    Args:
        build_config_file: the build configuration describing all builds.

    Raises:
        SystemExit: with status 1 if any build reports ``False``; with status 0
            from the ctrl-c handler.
    """

    # these steps are forced to run locally, once per config
    for build in build_config_file.builds_list:
        execute(build.bitbuilder.replace_rtl, hosts=['localhost'])
        execute(build.bitbuilder.build_driver, hosts=['localhost'])

    def release_build_hosts_handler(sig, frame) -> None:
        """ On ctrl-c, optionally release the build farm hosts, then exit. """
        rootLogger.info("You pressed ctrl-c, so builds have been killed.")

        # only prompt when termination was not forced on the command line
        if build_config_file.forceterminate:
            userconfirm = "yes"
        else:
            userconfirm = firesim_input("Do you also want to terminate your build hosts? Type 'yes' to do so.\n")

        if userconfirm != "yes":
            rootLogger.info("Build farm release skipped. There may still be build farm hosts running.")
        else:
            build_config_file.release_build_hosts()
            rootLogger.info("Build farm hosts released.")
        sys.exit(0)

    signal.signal(signal.SIGINT, release_build_hosts_handler)

    # local items (replace_rtl) were called in a loop above, for each config;
    # remote items map themselves onto the build hosts
    build_config_file.request_build_hosts()

    # wait until the build hosts are running, so that they have been
    # assigned IP addresses
    build_config_file.wait_on_build_host_initializations()

    @parallel
    def parallel_build_helper(config_file: BuildConfigFile, bypass: bool = False) -> bool:
        """Per-host portion of the build, run in parallel by fabric.

        Args:
            config_file: BuildConfigFile
            bypass: If true, immediately return and terminate build host. Used for testing purposes.

        Returns:
            Boolean indicating if the build passed or not.
        """
        # pick the build assigned to the host fabric is currently executing on
        host_build = config_file.get_build_by_ip(env.host_string)
        return host_build.bitbuilder.build_bitstream(bypass)

    # run the builds on every build host and collect the per-host results
    build_results = execute(parallel_build_helper, build_config_file, hosts=build_config_file.build_ip_set)
    any_build_failed = False in build_results.values()
    if any_build_failed:
        rootLogger.critical("ERROR: A bitstream build failed.")
        sys.exit(1)

@register_task
def builddriver(runtime_conf: RuntimeConfig) -> None:
    """ Only perform the driver build (host-processor side of an FPGA sim or
    an entire metasim).

    Delegates to ``runtime_conf.build_driver()``.

    Args:
        runtime_conf: runtime configuration the build is invoked on.
    """
    runtime_conf.build_driver()

@register_task
def enumeratefpgas(runtime_conf: RuntimeConfig) -> None:
    """ For all run hosts, create a FPGA DB file (default name given by config_runtime.yaml)

    Delegates to ``runtime_conf.enumerate_fpgas()``; the description above is
    the original author's summary of that call and is not verifiable from
    this function alone.

    Args:
        runtime_conf: runtime configuration the enumeration is invoked on.
    """
    runtime_conf.enumerate_fpgas()

@register_task
# XXX this needs to be renamed or rethought, perhaps this is a backend-specific task?
def tar2afi(build_config_file: BuildConfigFile) -> None:
    """
    Starting from the tarball generated by Vivado generate an AFI

    Walks ``build_config_file.builds_list`` in order and, for each build, runs
    its bit builder's ``aws_create_afi`` on localhost through fabric's
    ``execute``. Every bit builder is expected to be an ``F1BitBuilder``.
    Presumably the AFI submission, waiting and copying to other regions all
    happen inside ``aws_create_afi`` (not visible from here).

    NOTE: this is only useful if you also pass the --launchtime argument to firesim
    and specify an already existing launchtime for a previous, tarball generation

    Args:
        build_config_file: the build configuration describing all builds.
    """

    for build in build_config_file.builds_list:
        f1_bitbuilder = build.bitbuilder
        assert isinstance(f1_bitbuilder, F1BitBuilder)
        execute(f1_bitbuilder.aws_create_afi, build_config_file, build, hosts=['localhost'])



@register_task
def runcheck(runtime_conf: RuntimeConfig) -> None:
    """ Do nothing, just let the config process run.

    The only effect of this task is the construction of ``runtime_conf`` that
    happens before it is called.

    Args:
        runtime_conf: runtime configuration (unused).
    """
    pass


@register_task
def launchrunfarm(runtime_conf: RuntimeConfig) -> None:
    """ This starts an FPGA run farm, based on the parameters listed in the
    [runfarm] section of the config file. Instances are
    a) tagged with "runfarmtag" so that the manager can find them in the future
    and does not have to track state locally
    b) the list of IPs is always used locally AFTER sorting, so there is always
    a consistent "first" instance (useful for debugging)

    NOTE(review): the behavior described above is the original author's
    summary of ``run_farm.launch_run_farm()``; none of it is visible in this
    function, which only delegates.

    Args:
        runtime_conf: runtime configuration whose run farm is launched.
    """
    # all of the work happens in the run farm object
    runtime_conf.run_farm.launch_run_farm()


@register_task
def terminaterunfarm(runtime_conf: RuntimeConfig) -> None:
    """ Terminate instances in the run farm.

    This works in 2 modes:

    1) If you pass no --terminatesomeINSTANCETYPE flags, it will terminate all
       instances with the specified run farm tag.

    2) If you pass ANY --terminatesomeINSTANCETYPE flag, it will terminate only
       that many instances of the specified types and leave all others
       untouched.

    The flags are declared in ``construct_firesim_argparser`` (``--terminatesome``
    and its deprecated per-instance-type variants). The selection logic itself
    lives in ``runtime_conf.terminate_run_farm()``, to which this task delegates.

    Args:
        runtime_conf: runtime configuration whose run farm is terminated.
    """
    runtime_conf.terminate_run_farm()

@register_task
def shareagfi(build_config_file: BuildConfigFile) -> None:
    """ Share the agfis specified in the agfis-to-share section with the users
    specified in the share-with-accounts section

    Each name in ``build_config_file.agfistoshare`` is looked up in the
    hardware database to obtain its AGFI, which is then shared in all regions
    with ``build_config_file.acctids_to_sharewith``.

    Args:
        build_config_file: the build configuration listing AGFIs and accounts.
    """
    for agfi_name in build_config_file.agfistoshare:
        # resolve the human-readable name to the AGFI recorded in the hwdb
        hw_config = build_config_file.hwdb.get_runtimehwconfig_from_name(agfi_name)
        agfi_id = hw_config.agfi

        share_agfi_in_all_regions(agfi_id, build_config_file.acctids_to_sharewith)
        rootLogger.info("AGFI '%s': %s has been shared with the users specified in share-with-accounts",
                        agfi_name, agfi_id)



def construct_firesim_argparser() -> argparse.ArgumentParser:
    """Build the command line parser for the FireSim manager.

    The positional ``task`` argument accepts the names currently registered in
    `TASKS`; ``--platform`` accepts the names in `PLATFORM_LIST`. Shell
    completion is enabled on the parser through ``argcomplete``.

    Returns:
        the configured `argparse.ArgumentParser`
    """
    parser = argparse.ArgumentParser(description='FireSim Simulation Manager.')
    parser.add_argument('task', type=str,
                        help='Management task to run.', choices=list(TASKS))
    parser.add_argument('-c', '--runtimeconfigfile', type=str,
                        help='Optional custom runtime/workload config file. Defaults to config_runtime.yaml.',
                        default='config_runtime.yaml')
    parser.add_argument('-b', '--buildconfigfile', type=str,
                        help='Optional custom build config file. Defaults to config_build.yaml.',
                        default='config_build.yaml')
    parser.add_argument('-r', '--buildrecipesconfigfile', type=str,
                        help='Optional custom build recipe config file. Defaults to config_build_recipes.yaml.',
                        default='config_build_recipes.yaml')
    parser.add_argument('-a', '--hwdbconfigfile', type=str,
                        help='Optional custom HW database config file. Defaults to config_hwdb.yaml.',
                        default='config_hwdb.yaml')
    parser.add_argument('-x', '--overrideconfigdata', type=str,
                        help='Override a single value from one of the the RUNTIME e.g.: --overrideconfigdata "target-config link-latency 6405".',
                        default="")

    # deprecated per-instance-type flags; they differ only in flag names and
    # the instance type mentioned in the help text
    deprecated_terminatesome_flags = [
        ('-f', '--terminatesomef116', 'f1.16xlarge'),
        ('-g', '--terminatesomef12', 'f1.2xlarge'),
        ('-i', '--terminatesomef14', 'f1.4xlarge'),
        ('-m', '--terminatesomem416', 'm4.16xlarge'),
    ]
    for short_flag, long_flag, instance_type in deprecated_terminatesome_flags:
        parser.add_argument(short_flag, long_flag, type=int,
                            help=f'DEPRECATED. Use --terminatesome={instance_type}:count instead. Will be removed in the next major version of FireSim (1.15.X). Old help message: Only used by terminaterunfarm. Terminates this many of the previously launched {instance_type}s.',
                            default=-1)

    def terminatesomesplitter(raw_arg: str) -> Tuple[str, int]:
        """Convert ``INSTANCETYPE:COUNT`` into an ``(instance_type, count)`` tuple.

        Raises `argparse.ArgumentTypeError` on malformed input so that argparse
        reports a normal usage error instead of a traceback.
        """
        split_arg = raw_arg.split(":")
        if len(split_arg) != 2:
            raise argparse.ArgumentTypeError(
                f"Invalid terminatesome argument '{raw_arg}': expected INSTANCETYPE:COUNT")
        instance_type, count = split_arg
        try:
            return instance_type, int(count)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"Invalid terminatesome argument '{raw_arg}': COUNT must be an integer") from None

    parser.add_argument('--terminatesome', action='append', type=terminatesomesplitter,
                        help='Only used by terminaterunfarm. Used to specify a restriction on how many instances to terminate. E.g., --terminatesome=f1.2xlarge:2 will terminate only 2 of the f1.2xlarge instances in the runfarm, regardless of what other instances are in the runfarm. This argument can be specified multiple times to terminate additional instance types/counts. Behavior when specifying the same instance type multiple times is undefined. This replaces the old --terminatesome{f116,f12,f14,m416} arguments. Behavior when specifying these old-style terminatesome flags and this new style flag at the same time is also undefined.')
    parser.add_argument('-q', '--forceterminate', action='store_true',
                        help='For terminaterunfarm and buildbitstream, force termination without prompting user for confirmation. Defaults to False')
    parser.add_argument('-t', '--launchtime', type=str,
                        help='Give the "Y-m-d--H-M-S" prefix of results-build directory. Useful for tar2afi when finishing a partial buildbitstream')
    parser.add_argument('--platform', type=str, choices=PLATFORM_LIST, default='f1',
                        help='Required argument for "managerinit" to specify which platform you will be using')

    argcomplete.autocomplete(parser)
    return parser

def check_env() -> None:
    """Verify the assumptions made about the environment setup.

    The only check is that the `FIRESIM_SOURCED` environment variable is set.

    Raises:
        `SystemExit` if any assumption is violated
    """
    # the variable is present: the environment script has been sourced
    if 'FIRESIM_SOURCED' in os.environ:
        return

    rootLogger.critical("ERROR: You must source firesim/sourceme-manager.sh!")
    sys.exit(1)



def main(args: argparse.Namespace) -> None:
    """ Main function for FireSim manager.

    Configures fabric's connection settings, then looks up ``args.task`` in
    `TASKS` and calls it with the argument its registration asks for.

    Args:
        args: parsed command line arguments; ``args.task`` selects the task.
    """
    # large timeouts, retry connections
    env.timeout = 100
    env.connection_attempts = 10
    # we elastically spin instances up/down. we can easily get re-used IPs with
    # different keys. also, probably won't get MITM'd
    env.disable_known_hosts = True

    rootLogger.info("FireSim Manager. Docs: https://docs.fires.im\nRunning: %s\n", str(args.task))

    entry = TASKS[args.task]
    config_class = entry['config']
    if not config_class:
        # the task takes no parameters
        entry['task']()
    elif config_class is argparse.Namespace:
        # the task wants the raw parsed arguments
        entry['task'](args)
    else:
        # the task wants a config object built from the parsed arguments
        entry['task'](config_class(args))



if __name__ == '__main__':
    # set the program working dir to wherever firesim is located
    # this lets you run firesim from anywhere, not necessarily firesim/deploy/
    abspath = os.path.abspath(__file__)
    dname = os.path.dirname(abspath)
    os.chdir(dname)

    # parse args BEFORE setting up logs, otherwise argcomplete will cause
    # junk log files to be created. also lets us use args.task in the logfile
    # name
    args = construct_firesim_argparser().parse_args()

    # logging setup
    def logfilename() -> str:
        """ Construct a unique log file name from: date + task name + 16 char random. """
        timeline = strftime("%Y-%m-%d--%H-%M-%S", gmtime())
        randname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=16))
        return f"{timeline}-{args.task}-{randname}.log"

    # treat python warnings as WARNING level messages from py.warn
    logging.captureWarnings(True)

    rootLogger = logging.getLogger()
    rootLogger.setLevel(logging.NOTSET) # capture everything

    # log to file; FileHandler fails if the directory is missing, so create it
    os.makedirs("logs", exist_ok=True)
    full_log_filename = "logs/" + logfilename()
    fileHandler = logging.FileHandler(full_log_filename)
    # formatting for log to file
    # TODO: filehandler should be handler 0 (firesim_topology_with_passes expects this
    # to get the filename) - handle this more appropriately later
    logFormatter = logging.Formatter("%(asctime)s [%(funcName)-12.12s] [%(levelname)-5.5s]  %(message)s")
    fileHandler.setFormatter(logFormatter)
    fileHandler.setLevel(logging.NOTSET) # log everything to file
    rootLogger.addHandler(fileHandler)

    # log to stdout, without special formatting
    consoleHandler = logging.StreamHandler(stream=sys.stdout)
    consoleHandler.setLevel(logging.INFO) # show only INFO and greater in console
    rootLogger.addHandler(consoleHandler)

    # hide messages lower than warning from boto3/paramiko
    logging.getLogger('boto3').setLevel(logging.WARNING)
    logging.getLogger('botocore').setLevel(logging.WARNING)
    logging.getLogger('nose').setLevel(logging.WARNING)
    logging.getLogger("paramiko").setLevel(logging.WARNING)

    check_env()

    # lastly - we want anything printed to stdout to be converted into a DEBUG
    # level logging message and anything printed to stderr converted into INFO.
    # This is primarily because fabric does not use logging, it prints explicitly
    # to stdout and stderr.  We want its output to be logged.
    with StreamLogger('stdout'), InfoStreamLogger('stderr'):
        exitcode = 0
        try:
            main(args)
        except SystemExit as exc:
            # tasks and signal handlers exit via sys.exit(); honor the status
            # they asked for instead of turning every exit into a failure
            exitcode = exc.code
            if exitcode not in (None, 0):
                rootLogger.exception("Fatal error.")
        except BaseException:
            # log all other exceptions (including KeyboardInterrupt) that make it this far
            rootLogger.exception("Fatal error.")
            exitcode = 1
        finally:
            rootLogger.info("The full log of this run is:\n%s/%s", dname, full_log_filename)
            sys.exit(exitcode)
